import pandas as pd
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import semantic_search


# Candidate category labels that queries are matched against in
# categorize_qrys(). Search hits are mapped back to a label by list index,
# so the order of the entries must stay stable.
WN_CATS = [
    "art", "fashion", "food", "animals", "veterinary",
    "medicine", "archaeology", "artisanship", "commerce", "astronomy",
    "sports", "literature", "sexuality", "architecture", "chemistry",
    "gastronomy", "telecommunication", "military", "body care", "politics",
    "geography", "publishing", "agriculture", "tourism", "mathematics",
    "administration", "history", "physics", "earth science", "transport",
    "astrology", "religion", "law", "psychology", "philosophy",
    "pedagogy", "linguistics", "economy", "biology", "anthropology",
    "engineering", "industry", "sociology", "computer science", "children",
]


def categorize_qrys(fp_input, fp_output):
    """Label every query in a CSV with its best-matching ``WN_CATS`` entry.

    The CSV at ``fp_input`` is loaded with pandas and must provide a ``qry``
    column. The queries and the category names are embedded with the
    ``all-MiniLM-L6-v2`` sentence-transformer model, and ``semantic_search``
    selects the single top-scoring category for each query.

    Two columns are added before the frame is written to ``fp_output``:

    * ``category`` -- the matched entry of ``WN_CATS``
    * ``score`` -- the score ``semantic_search`` reported for that match

    Args:
        fp_input: Path (or anything ``pandas.read_csv`` accepts) of the
            input CSV.
        fp_output: Destination passed to ``DataFrame.to_csv``.
    """
    frame = pd.read_csv(fp_input)
    encoder = SentenceTransformer("all-MiniLM-L6-v2")

    query_vectors = encoder.encode(frame.qry.values)
    category_vectors = encoder.encode(WN_CATS)

    # With top_k=1 each query yields a one-element hit list; keep that hit.
    best_hits = [
        hits[0]
        for hits in semantic_search(query_vectors, category_vectors, top_k=1)
    ]

    # "corpus_id" is the position of the matched label inside WN_CATS.
    frame["category"] = [WN_CATS[hit["corpus_id"]] for hit in best_hits]
    frame["score"] = [hit["score"] for hit in best_hits]
    frame.to_csv(fp_output)


if __name__ == "__main__":
    # categorize_qrys() needs both an input and an output path; calling it
    # with no arguments raises TypeError, so read them from the command line.
    import argparse

    parser = argparse.ArgumentParser(
        description="Assign the closest category to every query in a CSV file."
    )
    parser.add_argument("fp_input", help="input CSV containing a 'qry' column")
    parser.add_argument("fp_output", help="path of the CSV file to write")
    args = parser.parse_args()

    categorize_qrys(args.fp_input, args.fp_output)
